#! /usr/bin/python3
import os
from multiprocessing import Process, freeze_support, Manager as mng

# Make this script file readable, writable and executable by everyone.
# NOTE(review): this runs `sudo` through the shell on every import of the
# module and interpolates the path unquoted -- confirm it is still needed.
os.system('sudo chmod 777 ' + str(__file__))

def main(d: 'multiprocessing.managers.DictProxy'):
    """Run the PyQt5 training-guide window until it is closed.

    The annotation is deliberately a string: an expression such as
    ``Manager().dict`` would be evaluated when the function is defined and
    would start a manager process on every import of this module.

    Args:
        d: Manager dict shared with the worker process.  This function reads
            and writes the keys ``'msg'``, ``'data'``, ``'pid'``,
            ``'is_preparing'`` and ``'return'``.
    """
    from sys import argv, exit
    from PyQt5.QtWidgets import QApplication, QMainWindow
    from PyQt5 import QtCore, QtGui
    app = QApplication(argv)
    from packages.yolov5 import data_train
    from res.ui import Ui_Yolov5TrainGuide as MAIN
    from PyQt5.QtCore import QTimer
    from PyQt5.QtGui import QIcon
    from PyQt5 import QtWidgets
    import pickle
    import shlex

    # Directory holding this script, with a trailing path separator.
    # dirname() is used instead of str.replace() so that a directory whose
    # name contains the script's file name is not mangled.
    this_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '')

    def resolve_location(text, allow_empty=False):
        """Expand a leading '~' in *text* and make it end with '/'.

        With ``allow_empty=True`` an empty string is returned unchanged
        instead of being turned into '/'.
        """
        if text.startswith('~'):
            # expanduser() honours $HOME / the passwd entry instead of
            # assuming /home/<user> and spawning a shell to find the user.
            text = os.path.expanduser(text)
        if allow_empty and not text:
            return text
        if not text.endswith('/'):
            text += '/'
        return text

    # Main window
    ###########################################################################
    class MainWindow(QMainWindow, MAIN):
        """Main window: collects settings, prepares data and drives training."""

        def __init__(self, parent=None, train_fun=None, mpdict=None):
            """Build the UI, wire signals and restore the last session.

            Args:
                parent: Optional Qt parent widget.
                train_fun: Module providing ``Preparation`` (stored as
                    ``self.data_train``).
                mpdict: Manager dict shared with the worker process.
            """
            super(MainWindow, self).__init__(parent)
            super(MainWindow, self).setupUi(self)

            # Keep the window at the size defined by the UI file.
            self.setFixedSize(self.size())

            self.loadIcon()

            # The GPU spin box is capped at, and defaults to, the device count.
            gpunum = data_train.device_count()
            self.gpuNum.setMaximum(gpunum)
            self.gpuNum.setValue(gpunum)

            # Periodic UI refresh; timeoutFun is invoked every millisecond.
            self.timer = QTimer()
            self.timer.start(1)
            self.timer.timeout.connect(self.timeoutFun)

            # Initial widget states (timeoutFun keeps most of them up to date).
            self.create_dir.setEnabled(True)
            self.prepare.setEnabled(True)
            self.generate_code.setEnabled(True)
            self.choose_model.setEnabled(False)

            self.model_size.setEnabled(True)
            self.envName.setReadOnly(True)

            # Signal connections.
            self.create_dir.clicked.connect(self.create_dir_ClickFun)
            self.prepare.clicked.connect(self.prepare_ClickFun)
            self.generate_code.clicked.connect(self.generate_code_ClickFun)
            self.browse.clicked.connect(self.browse_ClickFun)
            self.choose_model.clicked.connect(self.choose_model_ClickFun)
            self.startTrain.clicked.connect(self.startTrain_ClickFun)
            self.terminateTrain.clicked.connect(self.terminateTrain_ClickFun)
            self.browseConda.clicked.connect(self.browseConda_ClickFun)

            self.default_model.clicked.connect(self.enable_choose)
            self.other_model.clicked.connect(self.enable_choose)

            self.model_size.currentIndexChanged.connect(self.model_size_ClickFun)

            # Route the window's close event through close_Fun.
            self.closeEvent = self.close_Fun

            # Session state.
            self.train = None               # Preparation object, once known
            self.data_train = train_fun
            self.model_weight_file = None   # user-selected .pt file, if any
            self.msgShow = ['']             # lines shown in the log view
            self.isTraining = False
            self.d = mpdict

            self.load_lastdata()

            # Shell process used to run the training script; stderr is merged
            # into stdout so a single slot receives all output.
            self._process = QtCore.QProcess(self)
            self._process.setProcessChannelMode(QtCore.QProcess.MergedChannels)
            self._process.readyReadStandardOutput.connect(self.on_readReady)

            self.textBrowser.setOpenExternalLinks(True)
            self.writeToMsg(
                '<h3>Welcome Using YOLOv5 Training Guide Tool! <a href="https://www.bitlsh.top">Click Here</a> to visit author\'s personal website.</h3>')

        def writeToCMD(self, command):
            """Send *command*, terminated by a newline, to the shell process."""
            command += '\n'
            self._process.write(command.encode())

        def writeToMsg(self, msg):
            """Add *msg* to the log view, collapsing progress-bar updates.

            Text that does not contain both '<' and '>' is treated as plain
            text and HTML-escaped.  Consecutive progress updates (validation
            lines containing 'mAP@.5', or epoch lines containing
            '/<epochs-1>') overwrite the last shown line instead of being
            appended.  When the training script's completion marker is seen,
            ``terminateTrain_ClickFun`` is called.
            """
            # Must be tested before escaping: afterwards the spaces of the
            # marker are '&nbsp;' and the literal text no longer matches.
            training_done = 'This yolov5 training complete.' in msg

            # The not-yet-started progress bar tail as it looks after the
            # escaping below ('<' -> '&lt;', ' ' -> '&nbsp;').
            empty_bar = '[00:00&lt;?,&nbsp;?it/s]'

            if not ('<' in msg and '>' in msg):
                msg = msg.replace(' ', '&nbsp;')
                msg = msg.replace('<', '&lt;')
            msg = msg.replace('\n', '<br>').replace('\t', '<br>').replace('\r', '<br>')
            msg += '<br>'

            def last_bar(text):
                """Return the last ']'-terminated segment, or None without ']'."""
                parts = text.split(']')
                if len(parts) < 2:
                    return None
                return parts[-2] + ']'

            epoch_tag = '/%s' % str(self.epochs.value() - 1)

            if 'mAP@.5' in msg:
                if 'mAP@.5' in self.msgShow[-1]:
                    latest = last_bar(msg)
                    # A chunk cut before the closing ']' carries no complete
                    # update; keep the line that is already shown.
                    if latest is not None:
                        self.msgShow[-1] = latest
                else:
                    self.msgShow.append(msg)
            elif epoch_tag in msg:
                if epoch_tag in self.msgShow[-1]:
                    if empty_bar not in msg:
                        latest = last_bar(msg)
                        if latest is not None:
                            self.msgShow[-1] = latest
                else:
                    self.msgShow.append(msg)
            else:
                if 'Epoch' in msg and empty_bar in msg:
                    # Replace the raw header by a fixed, aligned one.
                    msg = '    Epoch   gpu_mem    box          obj           cls         labels   img_size  [00:00&lt;?, ?it/s]<br>'.replace(' ','&nbsp;')
                    self.msgShow.append(msg)
                elif empty_bar in msg and '0%' in msg and 'mAP@.5' not in self.msgShow[-1]:
                    # Empty 0% bar: nothing worth showing yet.
                    pass
                else:
                    if training_done:
                        self.terminateTrain_ClickFun()
                    self.msgShow.append(msg)

            self.textBrowser.setText(''.join(line + '\n' for line in self.msgShow))
            self.textBrowser.ensureCursorVisible()

            self.textBrowser.moveCursor(self.textBrowser.textCursor().End)

        @QtCore.pyqtSlot()
        def on_readReady(self):
            """Forward newly available output of the shell process to the log."""
            # errors='replace': a chunk may end in the middle of a multi-byte
            # character or contain non-UTF-8 bytes; that must not raise here.
            this_ret = bytes(self._process.readAllStandardOutput()).decode('utf-8', errors='replace')
            self.writeToMsg(this_ret)

        def startTrain_ClickFun(self):
            """Write the ``traincommand`` shell script and run it in bash."""
            print('startTrain')
            useconda = False
            condalocation = ''
            if self.useCondaEnv.isChecked() and len(self.envName.text()):
                useconda = True
                condalocation = self.envName.text()
                if not condalocation.endswith('/'):
                    condalocation += '/'
                print('activate conda')

            # Paths are shell-quoted so that spaces or shell metacharacters in
            # them cannot break (or alter) the generated script.
            interpreter = shlex.quote('%sbin/python3' % condalocation) if useconda else 'python3'
            location = resolve_location(self.location.text())
            print('change dir')
            script_lines = ['cd %s' % shlex.quote(location + self.data_name.text())]
            if self.gpuNum.value() > 1:
                print('multi gpu')
                script_lines.append(
                    '%s -m torch.distributed.launch --master_port 12345 --nproc_per_node %s start_train.py' % (
                        interpreter, str(self.gpuNum.value())))
            else:
                print('single gpu')
                script_lines.append('%s start_train.py' % interpreter)
            # Completion marker recognised by writeToMsg; emitted for both the
            # single- and the multi-GPU command so the UI leaves the training
            # state in either case.
            script_lines.append('echo This yolov5 training complete.')

            # Close the file before bash is asked to execute it.
            with open('traincommand', 'w') as script:
                script.write('\n'.join(script_lines) + '\n')
            os.system('chmod +x traincommand')
            self._process.start("bash")
            self.writeToCMD('./traincommand')
            self.isTraining = True

        def terminateTrain_ClickFun(self):
            """Stop the shell process and kill python3 processes using a GPU."""
            self._process.close()
            pids = []
            for line in os.popen('nvidia-smi').read().split('\n'):
                if 'python3' not in line:
                    continue
                fields = [field for field in line.split(' ') if len(field)]
                # The fifth field is taken as the PID.
                # NOTE(review): this depends on nvidia-smi's table layout and
                # matches every python3 process on the GPUs -- confirm.
                if len(fields) > 4:
                    pids.append(fields[4])
            for pid in pids:
                os.system('kill %s' % pid)
            if len(pids):
                self.writeToMsg('training progress terminated')
            self.isTraining = False

        def enable_choose(self):
            """Enable the weight-file button only while 'other model' is selected."""
            self.choose_model.setEnabled(self.other_model.isChecked())

        # Button: choose the .pt weight file to train from
        def choose_model_ClickFun(self):
            """Let the user pick a weight file and remember its path."""
            directory = QtWidgets.QFileDialog.getOpenFileName(None,
                                                              "Choose File",
                                                              os.path.expanduser("~"),
                                                              "pt Files (*.pt);;All Files (*)")
            if os.path.exists(directory[0]):
                self.model_weight_file = directory[0]
                self.writeToMsg("Weight file " + self.model_weight_file + " is choosen to be trained.")

        # Choose the parent folder of the data folder
        def browse_ClickFun(self):
            """Let the user pick the parent directory of the data directory."""
            print('browse now')
            directory = QtWidgets.QFileDialog.getExistingDirectory(None,
                                                                   "Choose Location",
                                                                   os.path.expanduser("~"))  # start path
            if len(directory):
                self.location.setText(directory)

        def browseConda_ClickFun(self):
            """Let the user pick an environment directory containing bin/python3."""
            print('browse conda now')
            directory = QtWidgets.QFileDialog.getExistingDirectory(None,
                                                                   "Choose Location",
                                                                   os.path.expanduser("~"))  # start path
            if not len(directory):
                # Dialog cancelled: do not test (or complain about) the
                # system-wide '/bin/python3'.
                return
            if os.path.exists('%s/bin/python3' % directory):
                self.envName.setText(directory)
            else:
                self.writeToMsg("Interpreter %s/bin/python3 doesn't exist!" % directory)

        # Create the folder that holds the data
        def create_dir_ClickFun(self):
            """Create the data directory skeleton and lock the location fields."""
            print("你按了 " + self.create_dir.text() + " 这个按钮")

            # ('a', 'b') is a placeholder class list; the real labels are read
            # from label.txt later.
            self.train = self.data_train.Preparation(address=self.location.text(), dir_name=self.data_name.text(),
                                                     classes=('a', 'b'))
            self.train.create_data_dir()

            self.writeToMsg(
                "Data dir has been created in " + self.location.text() + ", please read file READ_ME.txt in dir " + self.data_name.text())
            self.location.setEnabled(False)
            self.data_name.setEnabled(False)

        # Create the data used for training
        def prepare_ClickFun(self):
            """Hand a data-preparation request to the worker process."""
            print("你按了 " + self.prepare.text() + " 这个按钮")

            data_send = (self.train, (self.location.text(), self.data_name.text(), self.per.value()))
            self.d['data'] = data_send

        def load_lastdata(self):
            """Restore the widget values saved by ``close_Fun``, if present.

            An unreadable or incompatible ``res/lastdata`` is reported on
            stdout and ignored, so a damaged file cannot prevent start-up.
            """
            if not os.path.exists('res/lastdata'):
                return
            try:
                # NOTE(review): pickle can execute code while loading; this is
                # only acceptable while the file is written by this program.
                with open('res/lastdata', 'rb') as state_file:
                    location, dataName, per, \
                    modelSize, batchSize, epochs, \
                    gpuNum, defaultModel, otherModel, \
                    modelWeightFile, useConda, envName = pickle.load(state_file)
            except (OSError, EOFError, pickle.UnpicklingError, ValueError, TypeError) as err:
                print('ignoring unreadable res/lastdata: %s' % err)
                return

            self.location.setText(location)
            self.data_name.setText(dataName)
            self.per.setValue(per)
            self.model_size.setCurrentIndex(modelSize)
            self.batch_size.setValue(batchSize)
            self.epochs.setValue(epochs)
            self.gpuNum.setValue(gpuNum)
            self.default_model.setChecked(defaultModel)
            self.other_model.setChecked(otherModel)
            self.model_weight_file = modelWeightFile
            self.useCondaEnv.setChecked(useConda)
            self.envName.setText(envName)
            # setChecked() does not emit 'clicked', so bring the weight-file
            # button in line with the restored radio state explicitly.
            self.choose_model.setEnabled(self.other_model.isChecked())

        def generate_code_ClickFun(self):
            """Write option, hyp.yaml and start_train.py into the data directory."""
            print("你按了 " + self.generate_code.text() + " 这个按钮")
            location = resolve_location(self.location.text())
            data_dir = location + self.data_name.text()

            if self.train is None:
                # No Preparation object yet: build one from label.txt.
                with open(data_dir + "/label.txt") as label_file:
                    labels = tuple(str(label).rstrip("\n") for label in label_file.readlines())
                self.train = self.data_train.Preparation(address=self.location.text(), dir_name=self.data_name.text(),
                                                         classes=labels)

            # The net type is the text between parentheses in the combo entry.
            net_type = self.model_size.currentText().split('(')[1].split(')')[0]
            print(net_type, self.batch_size.value(),
                  self.epochs.value())

            option = self.train.get_train_options(net_type=net_type,
                                                  epochs=self.epochs.value(),
                                                  batch_size=self.batch_size.value(),
                                                  othermodel=self.model_weight_file if self.other_model.isChecked() else None)

            with open(data_dir + "/option", 'wb') as option_file:
                pickle.dump(option, option_file)

            if os.path.exists(option.default_hyp):
                with open(option.default_hyp) as default_hyp:
                    hyp_text = default_hyp.read()
                with open(data_dir + "/hyp.yaml", 'w') as hyp_file:
                    hyp_file.write(hyp_text)

            with open(this_dir + 'res/TrainCodeModel.py') as template:
                codetext = template.read().replace("$MODULE_PATH", this_dir[:-1])
            with open(data_dir + "/start_train.py", 'w') as run_file:
                run_file.write(codetext)

            # Only the last literal is %-formatted, so a '%' inside the path
            # cannot break the message.
            show_string = ("Done.<br>"
                           "Now you can open " + data_dir + "/hyp.yaml and tune hyperparameters.<br>"
                           "Then open Terminal and input the following commands to start training.<br><br>"
                           "cd " + data_dir + "<br>"
                           "python3 start_train.py    # for single GPU<br>"
                           "python3 -m torch.distributed.launch --nproc_per_node %s start_train.py    # for multi-GPU" % str(self.gpuNum.value()))
            self.writeToMsg(show_string)

        # Drop-down list
        def model_size_ClickFun(self):
            """Log the newly selected model size."""
            print("你将该下拉列表选项变成了 " + self.model_size.currentText())

        # Periodic refresh
        def timeoutFun(self):
            """Timer slot: relay worker messages and refresh widget states."""

            def ensure(widget, state):
                """Call setEnabled(state) only if isEnabled() differs."""
                if widget.isEnabled() != state:
                    widget.setEnabled(state)

            # Show a pending worker message and free the slot for the next one.
            pending = self.d['msg']
            if pending is not None:
                self.writeToMsg(pending)
                self.d['msg'] = None

            location = resolve_location(self.location.text(), allow_empty=True)
            name = self.data_name.text()
            data_dir = location + name

            create_dir = False
            prepare = False
            generate = False

            if os.path.exists(location) and len(name):
                dir_exists = os.path.exists(data_dir)
                has_readme = os.path.exists(data_dir + '/READ_ME.txt')
                if not dir_exists:
                    create_dir = True
                    self.create_dir.setText('Create Dir')
                elif not has_readme:
                    create_dir = True
                    self.create_dir.setText('Complete Dir')
                if dir_exists:
                    # Preparing needs label.txt plus either a directory created
                    # in this session (location field locked) or READ_ME.txt.
                    if not self.location.isEnabled() or has_readme:
                        if os.path.exists(data_dir + "/label.txt"):
                            prepare = True
                    if os.path.exists(data_dir + "/anchors.txt"):
                        generate = not (self.other_model.isChecked() and self.model_weight_file is None)
            self.create_dir.setEnabled(create_dir)
            self.prepare.setEnabled(prepare)
            self.generate_code.setEnabled(generate)

            # Conda widgets follow the checkbox.
            ensure(self.envName, self.useCondaEnv.isChecked())
            ensure(self.browseConda, self.useCondaEnv.isChecked())

            # Training can start once start_train.py exists and nothing runs.
            can_start = bool(os.path.isdir(data_dir)
                             and len(name)
                             and os.path.isfile('%s/start_train.py' % data_dir)
                             and not self.isTraining)
            ensure(self.startTrain, can_start)

            # Take over the Preparation object returned by the worker.
            returned = self.d['return']
            if returned is not None:
                self.train = returned
                self.d['return'] = None

            if self.isTraining or self.d['is_preparing']:
                # Busy: lock the inputs; 'terminate' is only available while
                # training.
                for widget in (self.startTrain, self.browseConda, self.useCondaEnv,
                               self.envName, self.groupBox, self.groupBox_2):
                    ensure(widget, False)
                ensure(self.terminateTrain, self.isTraining)
            else:
                if self.useCondaEnv.isChecked():
                    ensure(self.browseConda, True)
                ensure(self.useCondaEnv, True)
                if self.useCondaEnv.isChecked():
                    ensure(self.envName, True)
                ensure(self.groupBox, True)
                ensure(self.groupBox_2, True)
                ensure(self.terminateTrain, False)

        # Load the window icon
        def loadIcon(self):
            """Use icons/icon.ico, converting icons/icon.<type> images if needed."""
            if os.path.exists('icons/icon.ico'):
                self.setWindowIcon(QIcon('icons/icon.ico'))
            else:
                for imgtype in ['png', 'jpg', 'bmp', 'gif', 'jpeg']:
                    if os.path.exists('icons/icon.%s' % imgtype):
                        from packages import img2ico
                        img2ico('icons/icon.%s' % imgtype)
                        self.setWindowIcon(QIcon('icons/icon.ico'))

        def close_Fun(self, e):
            """Close handler: save the settings, stop the worker and training.

            Args:
                e: The Qt close event; it is accepted at the end.
            """
            savedata = (self.location.text(),
                        self.data_name.text(),
                        self.per.value(),
                        self.model_size.currentIndex(),
                        self.batch_size.value(),
                        self.epochs.value(),
                        self.gpuNum.value(),
                        self.default_model.isChecked(),
                        self.other_model.isChecked(),
                        self.model_weight_file,
                        self.useCondaEnv.isChecked(),
                        self.envName.text())
            # Written and closed before any process is killed.
            with open('res/lastdata', 'wb') as state_file:
                pickle.dump(savedata, state_file)

            worker_pid = self.d['pid']
            # The worker publishes its pid on start-up; skip the kill if it
            # has not done so yet instead of running 'kill None'.
            if worker_pid is not None:
                os.system('kill %s' % worker_pid)
            print(worker_pid)
            self.terminateTrain_ClickFun()
            e.accept()

    w = MainWindow(train_fun=data_train, mpdict=d)
    w.show()
    exit(app.exec())


# Run the anchor computation in a separate process so the GUI does not freeze.
def anchors_calculater(d: 'multiprocessing.managers.DictProxy'):
    """Worker loop: serve data-preparation requests posted in ``d['data']``.

    The annotation is deliberately a string: an expression such as
    ``Manager().dict`` would be evaluated when the function is defined and
    would start a manager process on every import of this module.

    The function publishes its pid in ``d['pid']`` and then loops forever.
    For each request it sets ``d['is_preparing']``, reports progress through
    ``d['msg']`` and, on success, stores the ``Preparation`` object in
    ``d['return']``.  The request and the ``is_preparing`` flag are cleared
    when a request ends, whether it succeeded, failed validation or raised.

    Args:
        d: Manager dict shared with the GUI process.
    """
    import time
    from packages.yolov5 import data_train

    poll_interval = 0.01  # seconds between polls; avoids a 100% CPU spin

    def post(text):
        """Wait until the GUI has consumed the previous message, then send."""
        while d['msg'] is not None:
            time.sleep(poll_interval)
        d['msg'] = text

    d['pid'] = str(os.getpid())
    print('anchors calculating program is in a subprocess which pid is %s' % str(os.getpid()))
    while True:
        request = d['data']
        if request is None:
            time.sleep(poll_interval)
            continue

        train, (location, data_name, per) = request
        d['is_preparing'] = True
        try:
            if train is None:
                # ('a', 'b') is a placeholder; the real classes are read from
                # label.txt below.
                train = data_train.Preparation(address=location, dir_name=data_name,
                                               classes=('a', 'b'))
            ret, msg = train.check_data()
            if not ret:
                post('[ERROR]:' + msg)
                continue

            post('############## - start generating data - ##############<br>')
            if location.startswith("~"):
                location = os.path.expanduser(location)
            if not location.endswith("/"):
                location += "/"
            with open(location + data_name + "/label.txt") as label_file:
                labels = tuple(str(label).rstrip("\n") for label in label_file.readlines())
            train.classes = labels

            post('splitting data into training data and test data')
            train.split_data(train_percent=per / 100)  # split into training and validation sets

            post('transforming xml label into txt label')
            train.create_voc_label()  # convert xml labels into txt label files

            post('creating data yaml')
            train.create_yaml()  # yaml file describing the data set

            post('calculating anchors(this may take a long time, please wait patiently)')
            train.clauculate_anchors(loop=1, d=d)  # compute anchors (anchors.txt)

            post('creating cfg yaml')
            train.create_cfg()  # yolo model file

            post('<br>############### - end generating data - ###############')
            d['return'] = train
        finally:
            # Always clear the request; otherwise a failed request would be
            # picked up again on the next iteration and the GUI would stay
            # locked in the "preparing" state.
            d['data'] = None
            d['is_preparing'] = False


if __name__ == '__main__':
    # Entry point: start the GUI and the worker as two processes that talk
    # through one manager dict, then wait for both to finish.
    freeze_support()

    # Shared state with every key both processes read before anything is
    # written by the other side.
    msg = mng().dict()
    msg.update({
        'msg': None,
        'data': None,
        'pid': None,
        'is_preparing': False,
        'return': None,
    })

    processes = [Process(target=entry, args=(msg, ))
                 for entry in (main, anchors_calculater)]

    for process in processes:
        process.start()
    for process in processes:
        process.join()
